import cv2
import torch
import lietorch
import numpy as np
import warnings
warnings.filterwarnings("ignore")
import torch.nn.functional as F
import geom.projective_ops as pops

from modules.corr import CorrBlock
from torchvision import transforms
from midas.omnidata import OmnidataModel

class MotionFilter:
    """Filter incoming frames by estimated motion and register keyframes.

    For every frame, ``track`` computes correlation features and estimates the
    flow magnitude against the last kept frame with one update iteration. A
    frame is appended to ``video`` together with its context features and
    optional monocular depth/normal priors in two cases: it is the first
    frame, or its mean flow magnitude exceeds ``thresh``, or more than 3 time
    units have passed since the last keyframe.
    """

    def __init__(self, net, video, config, disable_mono, device="cuda:0"):
        """
        Args:
            net: network exposing ``cnet`` (context encoder), ``fnet``
                (feature encoder) and ``update`` (update operator).
            video: keyframe buffer; this class reads ``counter.value`` and
                ``kf_stamps`` and calls ``append(...)``.
            config: dict-like config. Reads
                ``config["Tracking"]["motion_filter"]["thresh"]``, the optional
                ``...["init_thresh"]`` (defaults to ``thresh``), and the
                optional ``config["mono_model"]`` ('omnidata' by default, or
                'moge').
            disable_mono: if True, no monocular depth/normal priors are computed.
            device: device for network inputs and the mono-prior models.
        """
        # split net modules
        self.cnet = net.cnet
        self.fnet = net.fnet
        self.update = net.update
        self.disable_mono = disable_mono
        self.video = video
        self.mono_model = config.get("mono_model", "omnidata")

        motion_cfg = config["Tracking"]["motion_filter"]
        self.thresh = motion_cfg["thresh"]
        # NOTE: init_thresh is stored but not consulted by track() below.
        self.init_thresh = motion_cfg.get("init_thresh", self.thresh)
        self.device = device

        self.count = 0  # frames skipped since the last keyframe
        # mono-prior models are loaded lazily on first use in prior_extractor
        self.omni_dep = None
        self.omni_normal = None
        self.moge_model = None
        self.deltas = [0]  # history of per-frame mean flow magnitudes

        # mean, std for image normalization (ImageNet statistics)
        self.MEAN = torch.as_tensor([0.485, 0.456, 0.406], device=self.device)[:, None, None]
        self.STDV = torch.as_tensor([0.229, 0.224, 0.225], device=self.device)[:, None, None]

    @torch.cuda.amp.autocast(enabled=True)
    def context_encoder(self, image):
        """Return context features ``(net, inp)``: tanh / relu of the cnet output halves."""
        net, inp = self.cnet(image).split([128, 128], dim=2)
        return net.tanh().squeeze(0), inp.relu().squeeze(0)

    @torch.cuda.amp.autocast(enabled=True)
    def feature_encoder(self, image):
        """Return features used to build the correlation volume."""
        return self.fnet(image).squeeze(0)

    @torch.cuda.amp.autocast(enabled=True)
    @torch.no_grad()
    def prior_extractor(self, im_tensor):
        """Predict monocular depth and surface-normal priors for one frame.

        Args:
            im_tensor: image batch, assumed (1, 3, H, W) -- TODO confirm. For
                'omnidata' ``track`` passes the ImageNet-normalized input; for
                'moge' it passes the un-normalized [0, 1] RGB input.

        Returns:
            ``(depth, normal)`` at the input's (H, W), squeezed and cast to
            float32, or ``(None, None)`` when mono priors are disabled.

        Raises:
            ValueError: if ``self.mono_model`` is neither 'omnidata' nor 'moge'.
        """
        if self.disable_mono:
            return None, None
        input_size = im_tensor.shape[-2:]
        if self.mono_model == 'omnidata':
            # Omnidata is run at a fixed 512x512 resolution, then resized back.
            resize = transforms.Resize((512, 512), antialias=True)
            im_tensor = resize(im_tensor).to(self.device)
            if self.omni_dep is None:
                self.omni_dep = OmnidataModel('depth', 'pretrained_models/omnidata_dpt_depth_v2.ckpt', device=self.device)
                self.omni_normal = OmnidataModel('normal', 'pretrained_models/omnidata_dpt_normal_v2.ckpt', device=self.device)
            # NOTE(review): 50 is a fixed scale on the raw depth output; its
            # origin is not documented here.
            depth = self.omni_dep(im_tensor)[None] * 50
            depth = F.interpolate(depth, input_size, mode='bicubic')
            depth = depth.float().squeeze()
            # map the normal output from (presumably) [0, 1] to [-1, 1]
            normal = self.omni_normal(im_tensor) * 2.0 - 1.0
            normal = F.interpolate(normal, input_size, mode='bicubic')
            normal = normal.float().squeeze()
        elif self.mono_model == 'moge':
            from moge.model.v2 import MoGeModel  # MoGe-2; imported lazily (optional dependency)

            if self.moge_model is None:
                # Load the model from huggingface hub
                self.moge_model = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").to(self.device)
            output = self.moge_model.infer(im_tensor)
            depth = output["depth"].float().squeeze()
            # MoGe normals need no *2-1 remapping; move the last (channel)
            # axis to position 1 before squeezing.
            normal = output["normal"].permute(0, 3, 1, 2).float().squeeze()
        else:
            raise ValueError(
                f"Unsupported mono_model {self.mono_model!r}; expected 'omnidata' or 'moge'"
            )

        return depth, normal

    @torch.cuda.amp.autocast(enabled=True)
    @torch.no_grad()
    def track(self, t, tstamp, image, intrinsics=None, is_last=False, pose=None, gt_depth=None):
        """Process one frame; append it to the video if it is selected as a keyframe.

        Args:
            t: passed through to ``video.append`` as its first argument.
            tstamp: timestamp; compared against the last keyframe stamp and
                passed through to ``video.append``.
            image: BGR image batch, indexed as (N, 3, H, W); presumably values
                in [0, 255] given the /255 scaling.
            intrinsics: intrinsics tensor (assumed torch); the first 4 columns
                (presumably fx, fy, cx, cy) are rescaled by 1/8 to the feature
                resolution. Required despite the ``None`` default. The
                caller's tensor is not modified; a rescaled copy is stored.
            is_last, gt_depth: unused here.
            pose: passed through to ``video.append``.

        Raises:
            TypeError: if ``intrinsics`` is None.
        """
        if intrinsics is None:
            raise TypeError("track() requires camera intrinsics")

        ht = image.shape[-2] // 8
        wd = image.shape[-1] // 8
        # Features live at 1/8 resolution. Scale a copy so repeated calls with
        # a shared tensor do not keep dividing the caller's intrinsics.
        intrinsics = intrinsics.clone()
        intrinsics[:, :4] /= 8.0

        # normalize images: BGR -> RGB and [0, 255] -> [0, 1] (same as DROID-SLAM)
        inputs = image[None, :, [2, 1, 0]].to(self.device) / 255.0
        # MoGe normalizes internally, so it gets a copy taken before the
        # in-place ImageNet normalization below; other models use the
        # normalized tensor.
        raw_inputs = inputs.clone() if self.mono_model == 'moge' else None
        inputs = inputs.sub_(self.MEAN).div_(self.STDV)  # same as DROID-SLAM
        inputs_for_prior = inputs if raw_inputs is None else raw_inputs

        # extract features
        gmap = self.feature_encoder(inputs)

        ### always add first frame to the depth video ###
        if self.video.counter.value == 0:
            depth, normal = self.prior_extractor(inputs_for_prior[0])
            net, inp = self.context_encoder(inputs[:, [0]])
            self.net, self.inp, self.fmap = net, inp, gmap
            self.video.append(t, image[0], pose, 1.0, depth, normal, intrinsics, gmap, net[0], inp[0], tstamp)
            return

        ### only add new frame if there is enough motion ###
        # index correlation volume
        coords0 = pops.coords_grid(ht, wd, device=self.device)[None, None]
        corr = CorrBlock(self.fmap[None, [0]], gmap[None, [0]])(coords0)

        # approximate flow magnitude using 1 update iteration
        _, delta, _ = self.update(self.net[None], self.inp[None], corr)
        motion = delta.norm(dim=-1).mean().item()
        self.deltas.append(motion)

        # NOTE: could use a larger threshold (init_thresh) before initialization.
        # NOTE: the elapsed-time criterion (> 3) matters for IMU, where dT
        # cannot be too long because IMU integration cannot run too long; this
        # makes criteria_bamfslam more suitable for IMU (e.g. a sensor that
        # stays in the same place).
        if motion > self.thresh or (tstamp - self.video.kf_stamps[self.video.counter.value - 1]) > 3:
            self.count = 0
            net, inp = self.context_encoder(inputs[:, [0]])
            self.net, self.inp, self.fmap = net, inp, gmap
            depth, normal = self.prior_extractor(inputs_for_prior[0])
            self.video.append(t, image[0], pose, None, depth, normal, intrinsics, gmap, net[0], inp[0], tstamp)
        else:
            self.count += 1
